{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Azure 模型响应: 我是一个人工智能助手，旨在帮助回答问题和提供信息。有什么我可以协助你的吗？\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from langchain_openai import AzureChatOpenAI\n",
    "\n",
    "# Load variables from a local .env file into the process environment.\n",
    "load_dotenv()\n",
    "\n",
    "# Build the Azure OpenAI chat client from environment variables.\n",
    "# AzureChatOpenAI comes from langchain_openai (the package the next cell\n",
    "# also imports from); the langchain.chat_models import path is deprecated.\n",
    "llm = AzureChatOpenAI(\n",
    "    azure_endpoint=os.getenv(\"AZURE_OPENAI_ENDPOINT\"),  # azure_endpoint replaces openai_api_base\n",
    "    openai_api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"),\n",
    "    azure_deployment=os.getenv(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),  # azure_deployment replaces deployment_name\n",
    "    openai_api_version=os.getenv(\"AZURE_OPENAI_API_VERSION\"),\n",
    ")\n",
    "\n",
    "# Send one prompt and print the text of the reply.\n",
    "response = llm.invoke(\"请用中文回答：你是谁？\")\n",
    "print(\"Azure 模型响应:\", response.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='我是名为OpenAI的人工智能助手，旨在帮助回答问题和提供信息支持。有什么我可以帮您的吗？', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 28, 'prompt_tokens': 11, 'total_tokens': 39, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_ded0d14823', 'id': 'chatcmpl-BEtkVI7gYgY56Rp8d4HCfooLqBqPg', 'finish_reason': 'stop', 'logprobs': None, 'content_filter_results': {}}, id='run-939f7cf0-0b6e-47ba-a4bf-15a40bb096b5-0', usage_metadata={'input_tokens': 11, 'output_tokens': 28, 'total_tokens': 39, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "from langchain_core.messages import HumanMessage\n",
    "from langchain_openai import AzureChatOpenAI\n",
    "\n",
    "# Prompt for the API key only when it is not already set in the environment.\n",
    "if not os.environ.get(\"AZURE_OPENAI_API_KEY\"):\n",
    "    os.environ[\"AZURE_OPENAI_API_KEY\"] = getpass.getpass(\"Enter API key for Azure: \")\n",
    "\n",
    "# Endpoint, deployment and API version are read with os.environ[...], so a\n",
    "# missing variable raises KeyError here rather than passing None to the client.\n",
    "# NOTE(review): the API key is not passed explicitly - presumably the client\n",
    "# picks AZURE_OPENAI_API_KEY up from the environment; confirm.\n",
    "model = AzureChatOpenAI(\n",
    "    azure_endpoint=os.environ[\"AZURE_OPENAI_ENDPOINT\"],\n",
    "    azure_deployment=os.environ[\"AZURE_OPENAI_DEPLOYMENT_NAME\"],\n",
    "    openai_api_version=os.environ[\"AZURE_OPENAI_API_VERSION\"],\n",
    ")\n",
    "\n",
    "# The bare last expression makes the notebook display the returned message.\n",
    "model.invoke([HumanMessage(content=\"中文回答你是谁\")])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='You said your name is Bob. How can I help you today?', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 15, 'prompt_tokens': 33, 'total_tokens': 48, 'completion_tokens_details': {'accepted_prediction_tokens': 0, 'audio_tokens': 0, 'reasoning_tokens': 0, 'rejected_prediction_tokens': 0}, 'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0}}, 'model_name': 'gpt-4o-2024-08-06', 'system_fingerprint': 'fp_ded0d14823', 'id': 'chatcmpl-BEtkYuv4g879iaTFGyhVx95D1GXir', 'finish_reason': 'stop', 'logprobs': None, 'content_filter_results': {}}, id='run-4bbac35a-fb12-48b3-a486-558b67698345-0', usage_metadata={'input_tokens': 33, 'output_tokens': 15, 'total_tokens': 48, 'input_token_details': {'audio': 0, 'cache_read': 0}, 'output_token_details': {'audio': 0, 'reasoning': 0}})"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_core.messages import AIMessage\n",
    "\n",
    "# Hand-built conversation: the earlier turns are sent together with the new\n",
    "# question in a single invoke call.\n",
    "conversation = [\n",
    "    HumanMessage(content=\"Hi! I'm Bob\"),\n",
    "    AIMessage(content=\"Hello Bob! How can I assist you today?\"),\n",
    "    HumanMessage(content=\"What's my name?\"),\n",
    "]\n",
    "\n",
    "model.invoke(conversation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langgraph.checkpoint.memory import MemorySaver\n",
    "from langgraph.graph import START, MessagesState, StateGraph\n",
    "\n",
    "\n",
    "def call_model(state: MessagesState):\n",
    "    \"\"\"Graph node: send the messages held in the state to the model.\n",
    "\n",
    "    Returns a partial state update whose \"messages\" entry is the model reply.\n",
    "    \"\"\"\n",
    "    return {\"messages\": model.invoke(state[\"messages\"])}\n",
    "\n",
    "\n",
    "# One-node graph: START -> \"model\".\n",
    "workflow = StateGraph(state_schema=MessagesState)\n",
    "workflow.add_node(\"model\", call_model)\n",
    "workflow.add_edge(START, \"model\")\n",
    "\n",
    "# Compile with an in-memory checkpointer.\n",
    "memory = MemorySaver()\n",
    "app = workflow.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "Hello Bob! How can I assist you today?\n"
     ]
    }
   ],
   "source": [
    "# The thread id identifies the conversation for the checkpointer.\n",
    "config = {\"configurable\": {\"thread_id\": \"abc123\"}}\n",
    "query = \"Hi! I'm Bob.\"\n",
    "input_messages = [HumanMessage(content=query)]\n",
    "\n",
    "output = app.invoke({\"messages\": input_messages}, config=config)\n",
    "\n",
    "# output contains all messages in state; show only the newest one.\n",
    "latest_message = output[\"messages\"][-1]\n",
    "latest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "You mentioned that your name is Bob. How can I help you further?\n"
     ]
    }
   ],
   "source": [
    "# Follow-up question; config is reused from the previous cell.\n",
    "query = \"What's my name?\"\n",
    "input_messages = [HumanMessage(content=query)]\n",
    "\n",
    "output = app.invoke({\"messages\": input_messages}, config=config)\n",
    "newest_message = output[\"messages\"][-1]\n",
    "newest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "I'm sorry, but I don't have access to personal data about individuals unless it has been shared with me in the course of our conversation. My primary function is to provide information and answer questions to the best of my ability.\n"
     ]
    }
   ],
   "source": [
    "# Different thread id; query is still the question set in the previous cell.\n",
    "config = {\"configurable\": {\"thread_id\": \"abc234\"}}\n",
    "input_messages = [HumanMessage(content=query)]\n",
    "\n",
    "output = app.invoke({\"messages\": input_messages}, config=config)\n",
    "newest_message = output[\"messages\"][-1]\n",
    "newest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "Your name is Bob. If there's anything else you'd like to ask or discuss, feel free to let me know!\n"
     ]
    }
   ],
   "source": [
    "# Back to the first thread id; query is unchanged from the earlier cell.\n",
    "config = {\"configurable\": {\"thread_id\": \"abc123\"}}\n",
    "input_messages = [HumanMessage(content=query)]\n",
    "\n",
    "output = app.invoke({\"messages\": input_messages}, config=config)\n",
    "newest_message = output[\"messages\"][-1]\n",
    "newest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "I'm sorry, I can't determine your name.\n"
     ]
    }
   ],
   "source": [
    "# Async variant of the model node.\n",
    "async def call_model(state: MessagesState):\n",
    "    \"\"\"Await the model on the state's messages and return the reply as a state update.\"\"\"\n",
    "    reply = await model.ainvoke(state[\"messages\"])\n",
    "    return {\"messages\": reply}\n",
    "\n",
    "\n",
    "# Same single-node graph, rebuilt around the async node.\n",
    "# A new MemorySaver() is created here, so this app starts without the\n",
    "# checkpoints saved through the earlier `memory` object.\n",
    "workflow = StateGraph(state_schema=MessagesState)\n",
    "workflow.add_node(\"model\", call_model)\n",
    "workflow.add_edge(START, \"model\")\n",
    "app = workflow.compile(checkpointer=MemorySaver())\n",
    "\n",
    "# Top-level await (notebook only); input_messages and config come from the previous cell.\n",
    "output = await app.ainvoke({\"messages\": input_messages}, config=config)\n",
    "output[\"messages\"][-1].pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\n",
    "\n",
    "# A system instruction followed by a placeholder filled from the \"messages\" variable.\n",
    "pirate_system_message = (\n",
    "    \"system\",\n",
    "    \"You talk like a pirate. Answer all questions to the best of your ability.\",\n",
    ")\n",
    "history_slot = MessagesPlaceholder(variable_name=\"messages\")\n",
    "\n",
    "prompt_template = ChatPromptTemplate.from_messages([pirate_system_message, history_slot])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def call_model(state: MessagesState):\n",
    "    \"\"\"Graph node: render the prompt template from the state, then call the model.\"\"\"\n",
    "    rendered_prompt = prompt_template.invoke(state)\n",
    "    return {\"messages\": model.invoke(rendered_prompt)}\n",
    "\n",
    "\n",
    "# Rebuild the single-node graph around the prompt-aware node.\n",
    "workflow = StateGraph(state_schema=MessagesState)\n",
    "workflow.add_node(\"model\", call_model)\n",
    "workflow.add_edge(START, \"model\")\n",
    "\n",
    "memory = MemorySaver()\n",
    "app = workflow.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "Ahoy, Jim! Welcome aboard, matey! What be ye seekin' on this fine day? A tale of adventure or some treasure map ye might need deciphered?\n"
     ]
    }
   ],
   "source": [
    "# New thread id for the pirate-prompt app.\n",
    "config = {\"configurable\": {\"thread_id\": \"abc345\"}}\n",
    "query = \"Hi! I'm Jim.\"\n",
    "input_messages = [HumanMessage(content=query)]\n",
    "\n",
    "output = app.invoke({\"messages\": input_messages}, config=config)\n",
    "newest_message = output[\"messages\"][-1]\n",
    "newest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "Arrr, ye just be introducin' yerself as Jim, didn't ye, matey? Or be there somethin' more to the tale?\n"
     ]
    }
   ],
   "source": [
    "# Follow-up question; config is reused from the previous cell.\n",
    "query = \"What is my name?\"\n",
    "input_messages = [HumanMessage(content=query)]\n",
    "\n",
    "output = app.invoke({\"messages\": input_messages}, config=config)\n",
    "newest_message = output[\"messages\"][-1]\n",
    "newest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same layout as the earlier template, but the system text now has a {language} variable.\n",
    "assistant_system_message = (\n",
    "    \"system\",\n",
    "    \"You are a helpful assistant. Answer all questions to the best of your ability in {language}.\",\n",
    ")\n",
    "history_slot = MessagesPlaceholder(variable_name=\"messages\")\n",
    "\n",
    "prompt_template = ChatPromptTemplate.from_messages([assistant_system_message, history_slot])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Sequence\n",
    "\n",
    "from langchain_core.messages import BaseMessage\n",
    "from langgraph.graph.message import add_messages\n",
    "from typing_extensions import Annotated, TypedDict\n",
    "\n",
    "\n",
    "class State(TypedDict):\n",
    "    \"\"\"Graph state: the conversation messages plus the reply language.\"\"\"\n",
    "\n",
    "    # add_messages is attached to this field through Annotated.\n",
    "    messages: Annotated[Sequence[BaseMessage], add_messages]\n",
    "    # Supplies the {language} variable of prompt_template.\n",
    "    language: str\n",
    "\n",
    "\n",
    "def call_model(state: State):\n",
    "    \"\"\"Graph node: render the prompt from the whole state and call the model.\n",
    "\n",
    "    Returns a partial state update whose \"messages\" entry is a one-item list\n",
    "    holding the model reply.\n",
    "    \"\"\"\n",
    "    rendered_prompt = prompt_template.invoke(state)\n",
    "    reply = model.invoke(rendered_prompt)\n",
    "    return {\"messages\": [reply]}\n",
    "\n",
    "\n",
    "# Single-node graph over the custom State schema.\n",
    "workflow = StateGraph(state_schema=State)\n",
    "workflow.add_node(\"model\", call_model)\n",
    "workflow.add_edge(START, \"model\")\n",
    "\n",
    "memory = MemorySaver()\n",
    "app = workflow.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "¡Hola, Bob! ¿Cómo estás hoy?\n"
     ]
    }
   ],
   "source": [
    "config = {\"configurable\": {\"thread_id\": \"abc456\"}}\n",
    "query = \"Hi! I'm Bob.\"\n",
    "language = \"Spanish\"\n",
    "input_messages = [HumanMessage(content=query)]\n",
    "\n",
    "# The input now carries both state keys: the messages and the language.\n",
    "graph_input = {\"messages\": input_messages, \"language\": language}\n",
    "output = app.invoke(graph_input, config)\n",
    "\n",
    "newest_message = output[\"messages\"][-1]\n",
    "newest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "Tu nombre es Bob.\n"
     ]
    }
   ],
   "source": [
    "query = \"What is my name?\"\n",
    "input_messages = [HumanMessage(content=query)]\n",
    "\n",
    "# Only the new message is sent; \"language\" is left out of this input.\n",
    "# The recorded reply is still Spanish, so the value saved for this thread\n",
    "# appears to be reused.\n",
    "output = app.invoke({\"messages\": input_messages}, config=config)\n",
    "\n",
    "newest_message = output[\"messages\"][-1]\n",
    "newest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[SystemMessage(content=\"you're a good assistant\", additional_kwargs={}, response_metadata={}),\n",
       " HumanMessage(content='whats 2 + 2', additional_kwargs={}, response_metadata={}),\n",
       " AIMessage(content='4', additional_kwargs={}, response_metadata={}),\n",
       " HumanMessage(content='thanks', additional_kwargs={}, response_metadata={}),\n",
       " AIMessage(content='no problem!', additional_kwargs={}, response_metadata={}),\n",
       " HumanMessage(content='having fun?', additional_kwargs={}, response_metadata={}),\n",
       " AIMessage(content='yes!', additional_kwargs={}, response_metadata={})]"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_core.messages import SystemMessage, trim_messages\n",
    "\n",
    "# Rebuild the chat model for the trimming example; this rebinds the `model`\n",
    "# name that the graph cells below call.\n",
    "# azure_deployment selects the Azure deployment explicitly, as in the earlier cells.\n",
    "# model_name is kept because this model is also passed as the trimmer's token_counter.\n",
    "# NOTE(review): presumably the deployment name doubles as a model name the\n",
    "# token counter recognises - confirm, or supply the real model name here.\n",
    "model = AzureChatOpenAI(\n",
    "    azure_endpoint=os.getenv(\"AZURE_OPENAI_ENDPOINT\"),  # azure_endpoint replaces openai_api_base\n",
    "    openai_api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"),\n",
    "    azure_deployment=os.getenv(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),\n",
    "    model_name=os.getenv(\"AZURE_OPENAI_DEPLOYMENT_NAME\"),\n",
    "    openai_api_version=os.getenv(\"AZURE_OPENAI_API_VERSION\"),\n",
    ")\n",
    "\n",
    "# Trimmer configuration: keep the last messages within a 65-token budget\n",
    "# (counted by `model`), keep the system message, never cut a message in\n",
    "# half, and start the kept history on a human message.\n",
    "trimmer = trim_messages(\n",
    "    max_tokens=65,\n",
    "    strategy=\"last\",\n",
    "    token_counter=model,\n",
    "    include_system=True,\n",
    "    allow_partial=False,\n",
    "    start_on=\"human\",\n",
    ")\n",
    "\n",
    "# Sample history; later cells prepend it to their own question.\n",
    "messages = [\n",
    "    SystemMessage(content=\"you're a good assistant\"),\n",
    "    HumanMessage(content=\"hi! I'm bob\"),\n",
    "    AIMessage(content=\"hi!\"),\n",
    "    HumanMessage(content=\"I like vanilla ice cream\"),\n",
    "    AIMessage(content=\"nice\"),\n",
    "    HumanMessage(content=\"whats 2 + 2\"),\n",
    "    AIMessage(content=\"4\"),\n",
    "    HumanMessage(content=\"thanks\"),\n",
    "    AIMessage(content=\"no problem!\"),\n",
    "    HumanMessage(content=\"having fun?\"),\n",
    "    AIMessage(content=\"yes!\"),\n",
    "]\n",
    "\n",
    "# Display what survives trimming.\n",
    "trimmer.invoke(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "def call_model(state: State):\n",
    "    \"\"\"Graph node: trim the history, render the prompt, and call the model.\n",
    "\n",
    "    Only the trimmed messages and the state's language are given to the\n",
    "    prompt template; the reply is returned as a one-item \"messages\" update.\n",
    "    \"\"\"\n",
    "    prompt_inputs = {\n",
    "        \"messages\": trimmer.invoke(state[\"messages\"]),\n",
    "        \"language\": state[\"language\"],\n",
    "    }\n",
    "    reply = model.invoke(prompt_template.invoke(prompt_inputs))\n",
    "    return {\"messages\": [reply]}\n",
    "\n",
    "\n",
    "# Rebuild the single-node graph around the trimming node.\n",
    "workflow = StateGraph(state_schema=State)\n",
    "workflow.add_node(\"model\", call_model)\n",
    "workflow.add_edge(START, \"model\")\n",
    "\n",
    "memory = MemorySaver()\n",
    "app = workflow.compile(checkpointer=memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "I'm sorry, I don't know your name. You haven't shared it with me yet.\n"
     ]
    }
   ],
   "source": [
    "config = {\"configurable\": {\"thread_id\": \"abc567\"}}\n",
    "query = \"What is my name?\"\n",
    "language = \"English\"\n",
    "\n",
    "# Send the sample history from the trimmer cell followed by the new question.\n",
    "input_messages = [*messages, HumanMessage(content=query)]\n",
    "graph_input = {\"messages\": input_messages, \"language\": language}\n",
    "output = app.invoke(graph_input, config)\n",
    "\n",
    "newest_message = output[\"messages\"][-1]\n",
    "newest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
      "\n",
      "You asked: \"What's 2 + 2?\"\n"
     ]
    }
   ],
   "source": [
    "config = {\"configurable\": {\"thread_id\": \"abc678\"}}\n",
    "query = \"What math problem did I ask?\"\n",
    "language = \"English\"\n",
    "\n",
    "# Same sample history as before, this time with a question about a recent turn.\n",
    "input_messages = [*messages, HumanMessage(content=query)]\n",
    "graph_input = {\"messages\": input_messages, \"language\": language}\n",
    "output = app.invoke(graph_input, config)\n",
    "\n",
    "newest_message = output[\"messages\"][-1]\n",
    "newest_message.pretty_print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "|Hi| Todd|!| Here's| a| joke| for| you|:| \n",
      "\n",
      "|Why| do| we| never| tell| secrets| on| a| farm|?\n",
      "\n",
      "|Because| the| potatoes| have| eyes| and| the| corn| has| ears|!||"
     ]
    }
   ],
   "source": [
    "config = {\"configurable\": {\"thread_id\": \"abc789\"}}\n",
    "query = \"Hi I'm Todd, please tell me a joke.\"\n",
    "language = \"English\"\n",
    "input_messages = [HumanMessage(content=query)]\n",
    "\n",
    "# With stream_mode=\"messages\" the stream is iterated as (chunk, metadata) pairs.\n",
    "reply_stream = app.stream(\n",
    "    {\"messages\": input_messages, \"language\": language},\n",
    "    config,\n",
    "    stream_mode=\"messages\",\n",
    ")\n",
    "for chunk, metadata in reply_stream:\n",
    "    if not isinstance(chunk, AIMessage):\n",
    "        continue  # keep only model responses\n",
    "    # \"|\" marks the boundary between streamed chunks in the printed output.\n",
    "    print(chunk.content, end=\"|\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain-study",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
